import PIL.Image as Image
import torch
import random
from torch.utils.data import Dataset
import numpy as np
import re
from tqdm import tqdm
import os
import cv2 as cv

def normalize_image(image, mean=np.array([0.485, 0.456, 0.406]), std=1.):
    """Rescale an image array from the 0-255 range into [0, 1] as float32.

    Args:
        image: NumPy array of pixel values (any shape); it is not modified.
        mean: Accepted for signature compatibility but currently ignored --
            no per-channel mean subtraction is performed.
        std: Accepted for signature compatibility but currently ignored --
            no division by a standard deviation is performed.

    Returns:
        A new float32 array equal to ``image / 255``.
    """
    # Cast first so the division happens in float32 rather than integer space.
    as_float32 = image.astype(np.float32)
    return np.divide(as_float32, 255.0)

class ImageDataset(Dataset):
    """Image dataset that stacks each frame with its preceding frames.

    File names are expected to match ``state<frame>_<obj_name>.png`` and to
    live in ``<base>/<folder>/``; previous frames are looked up in
    ``<base>/<frame // FRAMES_PER_FOLDER>/``. ``obj_data`` is indexed as
    ``obj_data[frame]['Done']`` (stops frame-history stacking) and, when
    ``fit_linear`` is set, as ``obj_data[frame][obj_name][:4]`` (target).

    Args:
        file_paths: Image paths. The caller's sequence is NOT modified; a
            sorted copy is shuffled instead.
        obj_data: Per-frame metadata indexed by integer frame number.
        frame_stack: Number of frames per sample, current frame included.
        split: 'train' (first 80% after shuffling), 'val' (remaining 20%)
            or 'full' (all paths, shuffled).
        transform: Optional callable applied to the stacked image array, or
            only to the current frame when ``use_flow`` is set.
        ret_frame_info: If True, ``__getitem__`` returns a dict with the
            image tensor, frame number and object name.
        use_flow: If True, each sample is the Farneback optical flow between
            the oldest and the current frame, concatenated with the
            (transformed) current frame.
        fit_linear: If True and ``ret_frame_info`` is False,
            ``__getitem__`` returns ``(image, target)``.
        seed: Seed for the train/val shuffle. With a fixed seed, separate
            'train' and 'val' instances built from the same paths are
            disjoint and together cover every path. ``None`` falls back to
            the global ``random`` module state (previous behaviour, which
            can make independently built train/val sets overlap).

    Raises:
        ValueError: If ``split`` is not 'train', 'val' or 'full', or a path
            does not match the ``state<frame>_<obj_name>.png`` pattern.
    """

    # Frames are bucketed into sub-folders of this many frames each.
    FRAMES_PER_FOLDER = 2000
    # (width, height) every frame is resized to in ``load_image``.
    IMAGE_SIZE = (64, 64)
    _VALID_SPLITS = ('train', 'val', 'full')
    _FRAME_RE = re.compile(r'state(\d+)_')
    _STATE_PREFIX_RE = re.compile(r'state\d+_')
    _OBJ_RE = re.compile(r'state\d+_(.+)\.png')

    def __init__(self, file_paths, obj_data, frame_stack=3, split='train',
                 transform=normalize_image, ret_frame_info=False, use_flow=False,
                 fit_linear=False, seed=0):
        self.transform = transform
        self.split = split
        self.obj_data = obj_data
        self.frame_stack = frame_stack
        self.ret_frame_info = ret_frame_info
        self.use_flow = use_flow
        self.fit_linear = fit_linear

        self.file_paths, n_train, n_total = self._split_paths(file_paths, split, seed)
        print(" ---- ImageDataset ---- ")
        print("total images: ", n_total)
        print("train images: ", n_train)

        # Resolve every sample's frame-history paths once, up front.
        self.stacked_frame_paths = [
            self.preprocess_stacked_frames(path, i)
            for i, path in tqdm(enumerate(self.file_paths),
                                total=len(self.file_paths),
                                desc="Preprocessing Stacked Frames")
        ]

    @classmethod
    def _split_paths(cls, file_paths, split, seed):
        """Return ``(selected_paths, n_train, n_total)`` for ``split``.

        A sorted copy is shuffled, so the result depends only on the set of
        paths and ``seed``, not on the caller's ordering, and the caller's
        list is left untouched.
        """
        if split not in cls._VALID_SPLITS:
            raise ValueError("Invalid split value. Must be one of 'train', 'val' or 'full'.")
        paths = sorted(file_paths)
        # Use a private RNG so repeated constructions yield the same permutation.
        rng = random if seed is None else random.Random(seed)
        rng.shuffle(paths)
        cutoff = int(len(paths) * 0.8)
        if split == 'train':
            selected = paths[:cutoff]
        elif split == 'val':
            selected = paths[cutoff:]
        else:
            selected = paths
        return selected, cutoff, len(paths)

    @classmethod
    def _frame_number(cls, path):
        """Return the integer ``<frame>`` parsed from a ``state<frame>_`` path."""
        match = cls._FRAME_RE.search(path)
        if match is None:
            raise ValueError(f"Cannot parse frame number from path {path!r}")
        return int(match.group(1))

    @classmethod
    def _parse_state_path(cls, path):
        """Return ``(frame_number, obj_name)`` from a ``state<frame>_<obj>.png`` path."""
        frame_number = cls._frame_number(path)
        match = cls._OBJ_RE.search(path)
        if match is None:
            raise ValueError(f"Cannot parse object name from path {path!r}")
        return frame_number, match.group(1)

    def preprocess_stacked_frames(self, current_image_path, idx):
        """Return ``frame_stack`` paths for one sample, oldest first, current last.

        History stops at frame 0 or at a frame whose ``Done`` flag is set; the
        remaining slots repeat the earliest frame collected so far.
        ``idx`` is unused and kept only for call compatibility.
        """
        frame_number = self._frame_number(current_image_path)
        base_folder = os.path.dirname(os.path.dirname(current_image_path))
        current_name = os.path.basename(current_image_path)
        valid_frame_paths = [current_image_path]

        # NOTE(review): the *current* frame's Done flag is checked before
        # frame-1 is added. If Done marks the last frame of an episode, the
        # first frame of the next episode will stack the previous episode's
        # final frame, while a Done frame itself gets no history -- confirm
        # the intended semantics of 'Done'.
        done_flag = self.obj_data[frame_number]['Done']
        for offset in range(1, self.frame_stack):
            prev_frame_number = frame_number - offset
            if prev_frame_number >= 0 and not done_flag:
                prev_folder_path = os.path.join(
                    base_folder, str(prev_frame_number // self.FRAMES_PER_FOLDER))
                prev_file_name = self._STATE_PREFIX_RE.sub(
                    f'state{prev_frame_number}_', current_name)
                valid_frame_paths.insert(0, os.path.join(prev_folder_path, prev_file_name))
                done_flag = self.obj_data[prev_frame_number]['Done']
            else:
                # Pad by repeating the earliest frame collected so far.
                valid_frame_paths.insert(0, valid_frame_paths[0])

        return valid_frame_paths

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` as a channel-first float tensor (plus extras per flags)."""
        valid_frame_paths = self.stacked_frame_paths[idx]
        images = [self.load_image(path) for path in valid_frame_paths]

        if self.use_flow:
            # NOTE(review): load_image returns PIL-ordered (RGB) arrays, but the
            # conversion below assumes BGR; COLOR_RGB2GRAY may be intended.
            prev_gray = cv.cvtColor(images[0], cv.COLOR_BGR2GRAY)
            gray = cv.cvtColor(images[-1], cv.COLOR_BGR2GRAY)
            flow = cv.calcOpticalFlowFarneback(prev_gray, gray, None, pyr_scale=0.5, levels=5,
                                               winsize=15, iterations=3, poly_n=5,
                                               poly_sigma=1.2, flags=0)
            cur_image = images[-1]
            if self.transform:
                cur_image = self.transform(cur_image)
            stacked_image = np.concatenate([flow, cur_image], axis=2)
        else:
            stacked_image = np.concatenate(images, axis=2)
            if self.transform:
                stacked_image = self.transform(stacked_image)
        # HWC -> CHW for PyTorch.
        image_torch = torch.from_numpy(stacked_image.transpose(2, 0, 1)).float()

        if self.ret_frame_info or self.fit_linear:
            # The last stacked path is always the current frame's own path.
            frame_number, obj_name = self._parse_state_path(valid_frame_paths[-1])
            if self.ret_frame_info:
                return {'image': image_torch, 'frame_number': frame_number, 'obj_name': obj_name}
            target = torch.tensor(self.obj_data[frame_number][obj_name][:4], dtype=torch.float32)
            return image_torch, target

        return image_torch

    def load_image(self, image_path):
        """Open ``image_path`` with PIL, resize it to ``IMAGE_SIZE`` and return a NumPy array."""
        # Context manager guarantees the file handle is closed.
        with Image.open(image_path) as image:
            resized = image.resize(self.IMAGE_SIZE)
        return np.array(resized)